from pathlib import Path
from typing import Annotated, Literal, TypeAlias

from pydantic import BaseModel, Field, model_validator

from prime_rl.utils.config import ClientConfig, LogConfig, ModelConfig, WandbMonitorConfig
from prime_rl.utils.pydantic_config import BaseConfig, BaseSettings


class SamplingConfig(BaseConfig):
    """Configures how tokens are sampled from the model for training. Largely follows the vLLM sampling parameters."""

    temperature: Annotated[
        float,
        Field(
            ge=0,
            description="Scales the output probability distribution. Lower values => more deterministic, higher values => more random. If 0, will sample greedily.",
        ),
    ] = 1.0

    repetition_penalty: Annotated[
        float,
        Field(
            ge=0,
            description="Penalty for repeating tokens. Values > 1.0 discourage repetition, values < 1.0 encourage repetition, and 1.0 means no penalty.",
        ),
    ] = 1.0

    max_tokens: Annotated[
        int | None,
        Field(
            description="Maximum number of output tokens to generate per turn. If None, will generate until maximum context length or EOS token is hit.",
        ),
    ] = None

    min_tokens: Annotated[
        int,
        Field(
            ge=0,
            description="Minimum number of output tokens to generate per sequence.",
        ),
    ] = 0

    seed: Annotated[
        int | None,
        Field(
            description="Random seed to use for sampling. If None, no seeding is used.",
        ),
    ] = None


class EvalSamplingConfig(BaseConfig):
    """Configures how tokens are sampled from the model for evaluation. Largely follows the vLLM sampling parameters."""

    temperature: Annotated[
        float | None,
        Field(
            ge=0,
            description="Scales the output probability distribution. Lower values => more deterministic, higher values => more random. If 0, will sample greedily. Defaults to None, which means we fall back to the inference server's default value.",
        ),
    ] = None

    repetition_penalty: Annotated[
        float | None,
        Field(
            ge=0,
            description="Penalty for repeating tokens. Values > 1.0 discourage repetition, values < 1.0 encourage repetition, and 1.0 means no penalty. Defaults to None, which means we fall back to the inference server's default value.",
        ),
    ] = None

    top_p: Annotated[
        float | None,
        Field(
            description="Cumulative probability of the top tokens to consider. If 1, all tokens are considered. Defaults to None, which means we fall back to the inference server's default value.",
        ),
    ] = None

    top_k: Annotated[
        int | None,
        Field(
            description="Number of top tokens to consider. If -1, all tokens are considered. Defaults to None, which means we fall back to the inference server's default value.",
        ),
    ] = None

    min_p: Annotated[
        float | None,
        Field(
            description="Minimum probability for a token to be considered, relative to the probability of the most likely token. If 0, all tokens are considered. Defaults to None, which means we fall back to the inference server's default value.",
        ),
    ] = None

    max_tokens: Annotated[
        int | None,
        Field(
            description="Maximum number of output tokens to generate per turn. If None, will generate until maximum context length or EOS token is hit.",
        ),
    ] = None

    min_tokens: Annotated[
        int | None,
        Field(
            description="Minimum number of output tokens to generate per sequence. Defaults to None, which means we fall back to the inference server's default value.",
        ),
    ] = None

    reasoning_effort: Annotated[
        Literal["minimal", "low", "medium", "high"] | None,
        Field(
            description="Constrains effort on reasoning for reasoning models. Currently supported values are minimal, low, medium, and high. Defaults to None, which means we fall back to the inference server's default value.",
        ),
    ] = None

    seed: Annotated[
        int | None,
        Field(
            description="Random seed to use for sampling. If None, no seeding is used. Defaults to None, which means we fall back to the inference server's default value.",
        ),
    ] = None


class EvalSaveDiskConfig(BaseConfig):
    """Configures how to save the eval results to disk."""

    path: Annotated[
        Path | None,
        Field(
            description="The path to save the eval results to. If None, will default to <output_dir>/evals/<step_path>/<env_id> for online evals and the verifiers default for offline evals."
        ),
    ] = None


class EvalSaveHFConfig(BaseConfig):
    """Configures how to save the eval results to HF."""

    dataset_name: Annotated[
        str | None,
        Field(
            description="The name of the HF dataset to save the eval results to. If None, will auto-generate a name."
        ),
    ] = None

    dataset_subset: Annotated[
        str | None,
        Field(
            description="The subset name of the HF dataset to save the evaluation results. If None, will default to the environment ID.",
        ),
    ] = None

    dataset_split: Annotated[
        str | None,
        Field(
            description="The split name of the HF dataset to save the evaluation results. If None, will default to 'evals'.",
        ),
    ] = None

    private: Annotated[
        bool,
        Field(description="Whether to save the eval results to a private HF dataset."),
    ] = False


class EvalSaveConfig(BaseConfig):
    disk: EvalSaveDiskConfig | None = None
    hf: EvalSaveHFConfig | None = None
    env_hub: Annotated[
        bool,
        Field(
            description="Whether to push eval results to Prime Environment Hub. Automatically pushes all evaluated environments. Requires PRIME_API_KEY and authorization for the environments."
        ),
    ] = False


class EnvConfig(BaseConfig):
    """Configures an environment for training."""

    id: Annotated[str, Field(description="ID of the environment to use.")] = "reverse-text"
    args: Annotated[dict, Field(description="Arguments to pass to the environment.")] = {}
    name: Annotated[str | None, Field(description="Name of the environment to use.")] = None


class EvalEnvConfig(EnvConfig):
    """Configures an environment for evaluation."""

    num_examples: Annotated[
        int | None,
        Field(
            description="Number of examples to evaluate per environment. If not set, will use 'num_examples' from main config."
        ),
    ] = None
    rollouts_per_example: Annotated[
        int | None,
        Field(
            description="Number of samples to generate per example for each environment. If not set, will use 'rollouts_per_example' from main config."
        ),
    ] = None


class ValConfig(BaseConfig):
    """Configures the validation of the model."""

    num_examples: Annotated[
        int, Field(ge=1, description="Number of examples to use for validation. If -1, will use all examples.")
    ] = 16
    rollouts_per_example: Annotated[
        int, Field(ge=1, description="Number of samples to generate per example for validation.")
    ] = 1
    interval: Annotated[int, Field(description="Interval at which to validate the model.")] = 10


class EvalConfig(BaseConfig):
    """Configures evaluation using verifiers environments."""

    env: list[EvalEnvConfig] = [EvalEnvConfig()]
    sampling: EvalSamplingConfig = Field(
        default_factory=EvalSamplingConfig,
        description="Shared sampling configuration for evals; can differ from training sampling.",
    )
    save: EvalSaveConfig = Field(
        default_factory=EvalSaveConfig,
        description="Configures how to save the eval results.",
    )
    num_examples: Annotated[int, Field(description="Number of examples to evaluate per environment.")] = -1
    rollouts_per_example: Annotated[
        int, Field(ge=1, description="Number of samples to generate per example for each environment.")
    ] = 1


class OnlineEvalConfig(EvalConfig):
    """Configures online evaluation."""

    interval: Annotated[
        int,
        Field(
            ge=1,
            description="Interval at which to evaluate the model.",
        ),
    ] = 100

    eval_base_model: Annotated[
        bool,
        Field(
            description="Whether to evaluate the base model we are training on.",
        ),
    ] = True


class CheckpointConfig(BaseConfig):
    """Configures checkpointing the orchestrator."""

    interval: Annotated[int | None, Field(ge=1, description="Interval at which to save the checkpoint.")] = None

    resume_step: Annotated[
        int | None,
        Field(
            ge=-1,
            description="Step to resume orchestrator from. If None, will start from scratch. If -1, will restart from latest checkpoint available.",
        ),
    ] = None

    keep: Annotated[
        int | None,
        Field(
            ge=1,
            description="Keep at most this many recent step checkpoints on disk. If None, never clean old checkpoints.",
        ),
    ] = None

    skip_progress: Annotated[
        bool,
        Field(
            description="Whether to skip loading the progress from checkpoint.",
        ),
    ] = False

    skip_buffer: Annotated[
        bool,
        Field(
            description="Whether to skip loading the buffer from checkpoint.",
        ),
    ] = False

    buffer_path: Annotated[
        Path | None,
        Field(
            description="The path to load buffer state (metadata and rollouts) from. If None, will start with an empty buffer. The buffer state is saved at <ckpt_dir>/step_<step>/orchestrator/buffer.",
        ),
    ] = None


class BufferConfig(BaseConfig):
    """Configures the buffer for the orchestrator."""

    seed: Annotated[
        int | None,
        Field(
            description="Random seed to use for the buffer. If set, the sampling from the buffer will be deterministic.",
        ),
    ] = None

    easy_fraction: Annotated[
        float,
        Field(
            ge=0,
            le=1,
            description="Fraction of the batch that should consist of easy samples.",
        ),
    ] = 0.0

    hard_fraction: Annotated[
        float,
        Field(
            ge=0,
            le=1,
            description="Fraction of the batch that should consist of hard samples.",
        ),
    ] = 0.0

    easy_threshold: Annotated[
        float | None,
        Field(
            description="Threshold for easy difficulty classification. If average reward >= this threshold, mark as easy.",
        ),
    ] = None

    hard_threshold: Annotated[
        float | None,
        Field(
            description="Threshold for hard difficulty classification. If average reward <= this threshold, mark as hard.",
        ),
    ] = None

    online_difficulty_filtering: Annotated[
        bool,
        Field(
            description="Whether to filter rollouts based on their average reward. If True, rollouts with average reward == 0.0 will be marked as hard and rollouts with average reward == 1.0 will be marked as easy.",
        ),
    ] = False


class AdvantageConfig(BaseConfig):
    length_weighted_mean: bool = False


class FileSystemWeightBroadcastConfig(BaseModel):
    """Configures the filesystem weight broadcast."""

    type: Literal["filesystem"] = "filesystem"


class NCCLWeightBroadcastConfig(BaseModel):
    """Configures the NCCL weight broadcast."""

    type: Literal["nccl"] = "nccl"

    host: Annotated[str, Field(description="The host to use for the NCCL broadcast.")] = "localhost"
    port: Annotated[int, Field(description="The port to use for the NCCL broadcast.")] = 29501
    timeout: Annotated[int, Field(description="The timeout in seconds to use for the NCCL broadcast.")] = 1200


WeightBroadcastConfigType: TypeAlias = FileSystemWeightBroadcastConfig | NCCLWeightBroadcastConfig


class EnvMixConfig(BaseModel):
    """Configures the mixing of environments."""

    strategy: Literal["interleave", "concatenate"] = "interleave"
    probabilities: Annotated[list[float] | None, Field(description="Probabilities to use for each environment.")] = None
    stopping_strategy: Annotated[
        Literal["first_exhausted", "all_exhausted"],
        Field(description="Stopping strategy to use for interleaving environment datasets."),
    ] = "all_exhausted"
    seed: Annotated[int | None, Field(description="Random seed to use for the environment mixing.")] = None


class OrchestratorConfig(BaseSettings):
    """Configures the orchestrator for RL training."""

    # The OAI client configuration
    client: ClientConfig = ClientConfig()

    # The model configuration
    model: ModelConfig = ModelConfig()

    # The sampling configuration
    sampling: SamplingConfig = SamplingConfig()

    # The environment mixing configuration
    env_mix: EnvMixConfig = EnvMixConfig()

    # The environment configuration
    env: list[EnvConfig] = [EnvConfig()]

    # The evaluation configuration
    eval: OnlineEvalConfig | None = None

    # Data buffer configuration
    buffer: BufferConfig = BufferConfig()

    # The advantage configuration
    advantage: AdvantageConfig | None = AdvantageConfig()

    # The logging configuration
    log: LogConfig = LogConfig()

    # The wandb configuration
    wandb: WandbMonitorConfig | None = None

    # The checkpoint configuration
    ckpt: CheckpointConfig | None = None

    # The validation configuration
    val: ValConfig | None = None

    weight_broadcast: Annotated[WeightBroadcastConfigType, Field(discriminator="type")] = (
        FileSystemWeightBroadcastConfig()
    )

    output_dir: Annotated[
        Path,
        Field(
            description="Directory to write outputs to. Will be populated with checkpoints, weights, rollouts and logs as subdirectories. Should be set to a persistent directory with enough disk space. This value should be distinct across experiments running on a single node. See the README for more details."
        ),
    ] = Path("outputs")

    max_concurrent: Annotated[
        int | None,
        Field(
            description="Maximum number of concurrent rollouts to generate and score. Will create a global semaphore and pass to verifiers Environment. If None, will not limit concurrency.",
        ),
    ] = None

    batch_size: Annotated[int, Field(ge=1, description="Number of samples to train on per step.")] = 128

    oversampling_factor: Annotated[
        float,
        Field(
            ge=1,
            description="Factor by which to oversample the batch. Will lead to more in-flight group rollout requests at the same time.",
        ),
    ] = 1.0

    rollouts_per_example: Annotated[
        int,
        Field(
            ge=1,
            description="Number of output sequences to return per example during training.",
        ),
    ] = 1

    seq_len: Annotated[
        int,
        Field(
            description="Sequence length to use for training. If a sample is shorter than this, it will be padded. If a sequence is longer than this, it will be truncated.",
        ),
    ] = 2048

    mask_env_responses: Annotated[
        bool,
        Field(
            description="Whether to mask environment responses from the loss.",
        ),
    ] = True

    mask_truncated_completions: Annotated[
        bool,
        Field(
            description="Whether to mask truncated completions from the loss.",
        ),
    ] = False

    zero_truncated_completions: Annotated[
        bool,
        Field(
            description="Whether to override reward scores with 0 for truncated completions.",
        ),
    ] = False

    # TODO(Mika): This should be automatic from the number of ZMQ connections
    num_train_workers: Annotated[
        int,
        Field(default=1, ge=1, description="Number of training workers to use for training."),
    ] = 1

    max_steps: Annotated[
        int | None,
        Field(
            description="Maximum number of training steps to run. If None, will run indefinitely.",
        ),
    ] = None

    max_off_policy_steps: Annotated[
        int,
        Field(
            ge=0,
            description="Maximum number of policies that are allowed to generate a single rollout. Rollouts that are generated from more than `max_off_policy_steps` steps ahead of training will be discarded. Higher values yield better throughput, but lead to more off-policyness in training.",
        ),
    ] = 8

    max_async_level: Annotated[
        int,
        Field(
            ge=0,
            description="Maximum number of steps the inference can be ahead of training. If 0, will degenerate to synchronous on-policy RL. If >=1, training and inference will be overlapped.",
        ),
    ] = 1

    strict_async_level: Annotated[
        bool,
        Field(
            description="Whether to strictly enforce the max async level. If True, will always ensure that the policy used for generating rollouts is exactly `max_async_level` steps ahead of training. If False, any policy that is at most `max_async_level` steps ahead of training is allowed, i.e. we always use the latest available policy.",
        ),
    ] = True

    bench: Annotated[
        bool,
        Field(
            description="Whether to run in benchmark mode. It will automatically set the maximum number of steps to run to 5, max async level to ~infinity and disable W&B.",
        ),
    ] = False

    seed: Annotated[int | None, Field(description="Random seed for the orchestrator.")] = 42

    @model_validator(mode="after")
    def nccl_max_async_level(self):
        if self.weight_broadcast.type == "nccl":
            if not self.max_async_level == 1:
                raise ValueError("max_async_level must be 1 for NCCL broadcast")
        return self

    @model_validator(mode="after")
    def validate_batch_size(self):
        if self.batch_size % self.rollouts_per_example != 0:
            raise ValueError("Batch size must be divisible by the number of samples per problem")
        return self

    @model_validator(mode="after")
    def auto_setup_bench(self):
        if self.bench:
            self.max_steps = 4  # Run for 1 warmup step + 3 evaluation steps
            self.max_async_level = int(1e9)  # Never wait for RL weight checkpoints

            # Disable evaluation
            self.eval = None
            if self.wandb:
                self.wandb.log_extras = None

        return self
